{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deep Convolutional Generative Adversarial Networks\n",
    "\n",
    "In [our introduction to generative adversarial networks (GANs)](./gan-intro.ipynb), \n",
    "we introduced the basic ideas behind how GANs work.\n",
    "We showed that they can draw samples from some simple, easy-to-sample distribution,\n",
    "like a uniform or normal distribution, \n",
    "and transform them into samples that appear to match the distribution of some data set. \n",
    "And while our example of matching a 2D Gaussian distribution got the point across, it's not especially exciting.\n",
    "\n",
    "In this notebook, we'll demonstrate how you can use GANs \n",
    "to generate photorealistic images. \n",
    "We'll be basing our models on the deep convolutional GANs introduced in [this paper](https://arxiv.org/abs/1511.06434). \n",
    "We'll borrow the convolutional architecture that have proven so successful for discriminative computer vision problems\n",
    "and show how via GANs, they can be leveraged to generate photorealistic images. \n",
    "\n",
    "In this tutorial, concentrate on the [LWF Face Dataset](http://vis-www.cs.umass.edu/lfw/), \n",
    "which contains roughly 13000 images of faces. \n",
    "By the end of the tutorial, you'll know how to generate photo-realistic images of your own, given any dataset of images. First, we'll the the preliminaries out of the way.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import os\n",
    "import matplotlib as mpl\n",
    "import tarfile\n",
    "import matplotlib.image as mpimg\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "import mxnet as mx\n",
    "from mxnet import gluon\n",
    "from mxnet import ndarray as nd\n",
    "from mxnet.gluon import nn, utils\n",
    "from mxnet import autograd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set training parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "epochs = 2 # Set low by default for tests, set higher when you actually run this code.\n",
    "batch_size = 64\n",
    "latent_z_size = 100\n",
    "\n",
    "use_gpu = True\n",
    "ctx = mx.gpu() if use_gpu else mx.cpu()\n",
    "\n",
    "lr = 0.0002\n",
    "beta1 = 0.5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download and preprocess the LWF Face Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "lfw_url = 'http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz'\n",
    "data_path = 'lfw_dataset'\n",
    "if not os.path.exists(data_path):\n",
    "    os.makedirs(data_path)\n",
    "    data_file = utils.download(lfw_url)\n",
    "    with tarfile.open(data_file) as tar:\n",
    "        tar.extractall(path=data_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we resize images to size $64\\times64$. Then, we normalize all pixel values to the $[-1, 1]$ range."
   ]
  },
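  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the normalization described above (a sketch, assuming 8-bit pixel values $x \\in [0, 255]$), the map to $[-1, 1]$ and its inverse can be written as\n",
    "\n",
    "$$x' = \\frac{x}{127.5} - 1, \\qquad x = (x' + 1) \\cdot 127.5,$$\n",
    "\n",
    "so that $x = 0$ maps to $-1$, $x = 127.5$ maps to $0$, and $x = 255$ maps to $1$. This range matches the output range of a Tanh activation, which takes values in $(-1, 1)$."
   ]
  },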
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "target_wd = 64\n",
    "target_ht = 64\n",
    "img_list = []\n",
    "\n",
    "def transform(data, target_wd, target_ht):\n",
    "    # resize to target_wd * target_ht\n",
    "    data = mx.image.imresize(data, target_wd, target_ht)\n",
    "    # transpose from (target_wd, target_ht, 3) \n",
    "    # to (3, target_wd, target_ht)\n",
    "    data = nd.transpose(data, (2,0,1))\n",
    "    # normalize to [-1, 1]\n",
    "    data = data.astype(np.float32)/127.5 - 1\n",
    "    # if image is greyscale, repeat 3 times to get RGB image.\n",
    "    if data.shape[0] == 1:\n",
    "        data = nd.tile(data, (3, 1, 1))\n",
    "    return data.reshape((1,) + data.shape)\n",
    "\n",
    "for path, _, fnames in os.walk(data_path):\n",
    "    for fname in fnames:\n",
    "        if not fname.endswith('.jpg'):\n",
    "            continue\n",
    "        img = os.path.join(path, fname)\n",
    "        img_arr = mx.image.imread(img)\n",
    "        img_arr = transform(img_arr, target_wd, target_ht)\n",
    "        img_list.append(img_arr)\n",
    "train_data = mx.io.NDArrayIter(data=nd.concatenate(img_list), batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Visualize 4 images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def visualize(img_arr):\n",
    "    plt.imshow(((img_arr.asnumpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8))\n",
    "    plt.axis('off')\n",
    "\n",
    "for i in range(4):\n",
    "    plt.subplot(1,4,i+1)\n",
    "    visualize(img_list[i + 10][0])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining the networks\n",
    "\n",
    "The core to the DCGAN architecture uses a standard CNN architecture on the discriminative model. For the generator,\n",
    "convolutions are replaced with upconvolutions, so the representation at each layer of the generator is actually successively larger, as it mapes from a low-dimensional latent vector onto a high-dimensional image.\n",
    "\n",
    "* Replace any pooling layers with strided convolutions (discriminator) and fractional-strided convolutions (generator).\n",
    "\n",
    "* Use batch normalization in both the generator and the discriminator.\n",
    "\n",
    "* Remove fully connected hidden layers for deeper architectures.\n",
    "\n",
    "* Use ReLU activation in generator for all layers except for the output, which uses Tanh.\n",
    "\n",
    "* Use LeakyReLU activation in the discriminator for all layers.\n",
    "\n",
    "![](../img/dcgan.png \"DCGAN Architecture\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# build the generator\n",
    "nc = 3\n",
    "ngf = 64\n",
    "netG = nn.Sequential()\n",
    "with netG.name_scope():\n",
    "    # input is Z, going into a convolution\n",
    "    netG.add(nn.Conv2DTranspose(ngf * 8, 4, 1, 0, use_bias=False))\n",
    "    netG.add(nn.BatchNorm())\n",
    "    netG.add(nn.Activation('relu'))\n",
    "    # state size. (ngf*8) x 4 x 4\n",
    "    netG.add(nn.Conv2DTranspose(ngf * 4, 4, 2, 1, use_bias=False))\n",
    "    netG.add(nn.BatchNorm())\n",
    "    netG.add(nn.Activation('relu'))\n",
    "    # state size. (ngf*8) x 8 x 8\n",
    "    netG.add(nn.Conv2DTranspose(ngf * 2, 4, 2, 1, use_bias=False))\n",
    "    netG.add(nn.BatchNorm())\n",
    "    netG.add(nn.Activation('relu'))\n",
    "    # state size. (ngf*8) x 16 x 16\n",
    "    netG.add(nn.Conv2DTranspose(ngf, 4, 2, 1, use_bias=False))\n",
    "    netG.add(nn.BatchNorm())\n",
    "    netG.add(nn.Activation('relu'))\n",
    "    # state size. (ngf*8) x 32 x 32\n",
    "    netG.add(nn.Conv2DTranspose(nc, 4, 2, 1, use_bias=False))\n",
    "    netG.add(nn.Activation('tanh'))\n",
    "    # state size. (nc) x 64 x 64\n",
    "\n",
    "# build the discriminator\n",
    "ndf = 64\n",
    "netD = nn.Sequential()\n",
    "with netD.name_scope():\n",
    "    # input is (nc) x 64 x 64\n",
    "    netD.add(nn.Conv2D(ndf, 4, 2, 1, use_bias=False))\n",
    "    netD.add(nn.LeakyReLU(0.2))\n",
    "    # state size. (ndf) x 32 x 32\n",
    "    netD.add(nn.Conv2D(ndf * 2, 4, 2, 1, use_bias=False))\n",
    "    netD.add(nn.BatchNorm())\n",
    "    netD.add(nn.LeakyReLU(0.2))\n",
    "    # state size. (ndf) x 16 x 16\n",
    "    netD.add(nn.Conv2D(ndf * 4, 4, 2, 1, use_bias=False))\n",
    "    netD.add(nn.BatchNorm())\n",
    "    netD.add(nn.LeakyReLU(0.2))\n",
    "    # state size. (ndf) x 8 x 8\n",
    "    netD.add(nn.Conv2D(ndf * 8, 4, 2, 1, use_bias=False))\n",
    "    netD.add(nn.BatchNorm())\n",
    "    netD.add(nn.LeakyReLU(0.2))\n",
    "    # state size. (ndf) x 4 x 4\n",
    "    netD.add(nn.Conv2D(1, 4, 1, 0, use_bias=False))"
   ]
  },
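  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The spatial sizes in the network definitions above can be checked by hand. As a sketch, assuming square inputs, no dilation, and no output padding: for input size $n$, kernel size $k$, stride $s$, and padding $p$, a convolution and a transposed convolution produce outputs of size\n",
    "\n",
    "$$n_{\\mathrm{conv}} = \\left\\lfloor \\frac{n + 2p - k}{s} \\right\\rfloor + 1, \\qquad n_{\\mathrm{transposed}} = (n - 1)\\,s - 2p + k.$$\n",
    "\n",
    "With $k = 4$, $s = 2$, $p = 1$, the transposed convolution gives $(n - 1) \\cdot 2 - 2 + 4 = 2n$, and the convolution gives $\\lfloor (n - 2)/2 \\rfloor + 1 = n/2$ for even $n$. The first generator layer ($k = 4$, $s = 1$, $p = 0$) maps the $1 \\times 1$ latent input to $4 \\times 4$, so the generator grows as $1 \\to 4 \\to 8 \\to 16 \\to 32 \\to 64$. The discriminator shrinks as $64 \\to 32 \\to 16 \\to 8 \\to 4$, and its last layer ($k = 4$, $s = 1$, $p = 0$) reduces $4 \\times 4$ to a single $1 \\times 1$ score."
   ]
  },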
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup Loss Function and Optimizer\n",
    "We use binary cross-entropy as our loss function and use the Adam optimizer. We initialize the network's parameters by sampling from a normal distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# loss\n",
    "loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()\n",
    "\n",
    "# initialize the generator and the discriminator\n",
    "netG.initialize(mx.init.Normal(0.02), ctx=ctx)\n",
    "netD.initialize(mx.init.Normal(0.02), ctx=ctx)\n",
    "\n",
    "# trainer for the generator and the discriminator\n",
    "trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})\n",
    "trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': lr, 'beta1': beta1})"
   ]
  },
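  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To connect the loss above to the usual GAN objectives, here is a sketch under the assumption that the discriminator outputs a raw score (logit) $D(\\cdot)$ which the loss passes through the sigmoid $\\sigma$. For a score $o$ and a label $y \\in \\{0, 1\\}$, the binary cross-entropy is\n",
    "\n",
    "$$\\ell(o, y) = -\\big[\\, y \\log \\sigma(o) + (1 - y) \\log(1 - \\sigma(o)) \\,\\big].$$\n",
    "\n",
    "With label $1$ for real images $x$ and label $0$ for generated images $G(z)$, the discriminator and generator losses become\n",
    "\n",
    "$$L_D = -\\log \\sigma(D(x)) - \\log\\big(1 - \\sigma(D(G(z)))\\big), \\qquad L_G = -\\log \\sigma(D(G(z))).$$\n",
    "\n",
    "Minimizing $L_D$ is the same as maximizing $\\log \\sigma(D(x)) + \\log(1 - \\sigma(D(G(z))))$, and minimizing $L_G$ is the same as maximizing $\\log \\sigma(D(G(z)))$."
   ]
  },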
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Loop\n",
    "We recommend thst you use a GPU for training this model. After a few epochs, we can see human-face-like images are generated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from datetime import datetime\n",
    "import time\n",
    "import logging\n",
    "\n",
    "real_label = nd.ones((batch_size,), ctx=ctx)\n",
    "fake_label = nd.zeros((batch_size,),ctx=ctx)\n",
    "\n",
    "def facc(label, pred):\n",
    "    pred = pred.ravel()\n",
    "    label = label.ravel()\n",
    "    return ((pred > 0.5) == label).mean()\n",
    "metric = mx.metric.CustomMetric(facc)\n",
    "\n",
    "stamp =  datetime.now().strftime('%Y_%m_%d-%H_%M')\n",
    "logging.basicConfig(level=logging.DEBUG)\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    tic = time.time()\n",
    "    btic = time.time()\n",
    "    train_data.reset()\n",
    "    iter = 0\n",
    "    for batch in train_data:\n",
    "        ############################\n",
    "        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))\n",
    "        ###########################\n",
    "        data = batch.data[0].as_in_context(ctx)\n",
    "        latent_z = mx.nd.random_normal(0, 1, shape=(batch_size, latent_z_size, 1, 1), ctx=ctx)\n",
    "\n",
    "        with autograd.record():\n",
    "            # train with real image\n",
    "            output = netD(data).reshape((-1, 1))\n",
    "            errD_real = loss(output, real_label)\n",
    "            metric.update([real_label,], [output,])\n",
    "\n",
    "            # train with fake image\n",
    "            fake = netG(latent_z)\n",
    "            output = netD(fake.detach()).reshape((-1, 1))\n",
    "            errD_fake = loss(output, fake_label)\n",
    "            errD = errD_real + errD_fake\n",
    "            errD.backward()\n",
    "            metric.update([fake_label,], [output,])\n",
    "\n",
    "        trainerD.step(batch.data[0].shape[0])\n",
    "\n",
    "        ############################\n",
    "        # (2) Update G network: maximize log(D(G(z)))\n",
    "        ###########################\n",
    "        with autograd.record():\n",
    "            fake = netG(latent_z)\n",
    "            output = netD(fake).reshape((-1, 1))\n",
    "            errG = loss(output, real_label)\n",
    "            errG.backward()\n",
    "\n",
    "        trainerG.step(batch.data[0].shape[0])\n",
    "\n",
    "        # Print log infomation every ten batches\n",
    "        if iter % 10 == 0:\n",
    "            name, acc = metric.get()\n",
    "            logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))\n",
    "            logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d' \n",
    "                     %(nd.mean(errD).asscalar(), \n",
    "                       nd.mean(errG).asscalar(), acc, iter, epoch))\n",
    "        iter = iter + 1\n",
    "        btic = time.time()\n",
    "\n",
    "    name, acc = metric.get()\n",
    "    metric.reset()\n",
    "    # logging.info('\\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))\n",
    "    # logging.info('time: %f' % (time.time() - tic))\n",
    "\n",
    "    # Visualize one generated image for each epoch\n",
    "    # fake_img = fake[0]\n",
    "    # visualize(fake_img)\n",
    "    # plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results\n",
    "Given a trained generator, we can generate some images of faces."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "num_image = 8\n",
    "for i in range(num_image):\n",
    "    latent_z = mx.nd.random_normal(0, 1, shape=(1, latent_z_size, 1, 1), ctx=ctx)\n",
    "    img = netG(latent_z)\n",
    "    plt.subplot(2,4,i+1)\n",
    "    visualize(img[0])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also interpolate along the manifold between images by interpolating linearly between points in the latent space and visualizing the corresponding images. We can see that small changes in the latent space results in smooth changes in generated images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "num_image = 12\n",
    "latent_z = mx.nd.random_normal(0, 1, shape=(1, latent_z_size, 1, 1), ctx=ctx)\n",
    "step = 0.05\n",
    "for i in range(num_image):\n",
    "    img = netG(latent_z)\n",
    "    plt.subplot(3,4,i+1)\n",
    "    visualize(img[0])\n",
    "    latent_z += 0.05\n",
    "plt.show()"
   ]
  },
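  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The walk through the latent space in the cell above can be written down explicitly. As a sketch, let $z_0$ be the initial random latent vector, $\\delta = 0.05$ the step size, and $\\mathbf{1}$ the all-ones vector. The $k$-th image is then generated from\n",
    "\n",
    "$$z_k = z_0 + k \\, \\delta \\, \\mathbf{1}, \\qquad k = 0, 1, \\ldots, 11.$$\n",
    "\n",
    "A linear interpolation between two latent points $z_0$ and $z_1$ (where $z_1$ is a hypothetical second sample, not used in the code above) would instead use\n",
    "\n",
    "$$z(t) = (1 - t) \\, z_0 + t \\, z_1, \\qquad t \\in [0, 1].$$"
   ]
  },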
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For whinges or inquiries, [open an issue on  GitHub.](https://github.com/zackchase/mxnet-the-straight-dope)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
